Introduction

Linear Discriminant Analysis (LDA) is a powerful technique for dimensionality reduction and classification. It works by finding linear combinations of features that best separate different classes in the data. Unlike Principal Component Analysis (PCA), which focuses on maximizing variance, LDA specifically aims to maximize the separability between classes.

Key Concepts

  • Dimensionality Reduction: Reduces the number of features while preserving class separability
  • Supervised Learning: Uses class labels to guide the transformation
  • Linear Separability: Assumes that classes can be separated by linear boundaries
  • Gaussian Assumptions: Works best when data follows multivariate normal distributions with equal covariance matrices
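To make the PCA-versus-LDA contrast concrete, here is a small illustrative sketch on R's built-in iris data (iris is a stand-in here, not the sample data used below). PCA's first axis ignores the class labels entirely, while LDA's first discriminant explicitly maximizes between-class spread relative to within-class spread:

```r
library(MASS)

# Fit both methods on the same four features
pca_fit <- prcomp(iris[, 1:4], scale. = TRUE)
lda_fit <- lda(Species ~ ., data = iris)

# First-axis scores under each method
pc1 <- pca_fit$x[, 1]
ld1 <- predict(lda_fit)$x[, 1]

# Between-class variance of scores relative to average within-class variance
# (the quantity LDA's first discriminant maximizes over linear combinations)
fisher_ratio <- function(scores, labels) {
  m <- tapply(scores, labels, mean)
  var(m) / mean(tapply(scores, labels, var))
}
fisher_ratio(pc1, iris$Species)  # PC1: separation is incidental
fisher_ratio(ld1, iris$Species)  # LD1: larger, since LDA optimizes this
```

Because LD1 is the linear combination that maximizes this ratio, any other linear direction, including PC1, scores at most as high.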

Sample Data Demonstration

Let’s start with a simple example to understand how LDA works on well-separated classes.

# Function to generate sample data for demonstration
generate_sample_data <- function(n_samples = 300) {
  # Generate three classes with different means and same covariance
  class1 <- MASS::mvrnorm(n = n_samples/3, 
                          mu = c(50, 60, 70), 
                          Sigma = matrix(c(100, 20, 15, 20, 100, 25, 15, 25, 100), 3, 3))
  class2 <- MASS::mvrnorm(n = n_samples/3, 
                          mu = c(80, 40, 50), 
                          Sigma = matrix(c(100, 20, 15, 20, 100, 25, 15, 25, 100), 3, 3))
  class3 <- MASS::mvrnorm(n = n_samples/3, 
                          mu = c(30, 90, 80), 
                          Sigma = matrix(c(100, 20, 15, 20, 100, 25, 15, 25, 100), 3, 3))
  
  # Combine data
  sample_data <- rbind(
    data.frame(feature1 = class1[,1], feature2 = class1[,2], feature3 = class1[,3], class = "Class A"),
    data.frame(feature1 = class2[,1], feature2 = class2[,2], feature3 = class2[,3], class = "Class B"),
    data.frame(feature1 = class3[,1], feature2 = class3[,2], feature3 = class3[,3], class = "Class C")
  )
  
  return(sample_data)
}

# Generate sample data
sample_data <- generate_sample_data()

Let’s examine the structure of our sample data:

# Display first few rows
head(sample_data)
##   feature1 feature2 feature3   class
## 1 52.77124 50.77344 88.78080 Class A
## 2 56.13956 52.65127 77.00099 Class A
## 3 37.88582 50.65953 59.17225 Class A
## 4 48.08763 56.45356 74.08270 Class A
## 5 41.14525 63.27559 71.86210 Class A
## 6 37.98836 50.58005 55.97900 Class A
# Summary statistics by class
sample_data %>%
  group_by(class) %>%
  summarise(
    n = n(),
    mean_feature1 = mean(feature1),
    mean_feature2 = mean(feature2),
    mean_feature3 = mean(feature3),
    sd_feature1 = sd(feature1),
    sd_feature2 = sd(feature2),
    sd_feature3 = sd(feature3)
  ) %>%
  kable(digits = 2)
class     n   mean_feature1  mean_feature2  mean_feature3  sd_feature1  sd_feature2  sd_feature3
Class A   100         48.88          58.71          70.55         9.87         8.68         9.55
Class B   100         80.93          40.38          49.49        10.56         9.70         9.73
Class C   100         31.93          90.32          80.97        10.37         9.71        10.77

Now let’s visualize the original three-dimensional data through its pairwise 2D projections:

# Pairwise 2D scatter plots of the three original features
p1 <- ggplot(sample_data, aes(x = feature1, y = feature2, color = class)) +
  geom_point(size = 2, alpha = 0.7) +
  labs(title = "Sample Data: Original Features 1 vs 2",
       x = "Feature 1", y = "Feature 2") +
  theme_minimal() +
  theme(legend.position = "bottom",
        panel.background = element_rect(fill = "white"))

p2 <- ggplot(sample_data, aes(x = feature1, y = feature3, color = class)) +
  geom_point(size = 2, alpha = 0.7) +
  labs(title = "Sample Data: Original Features 1 vs 3",
       x = "Feature 1", y = "Feature 3") +
  theme_minimal() +
  theme(legend.position = "bottom",
        panel.background = element_rect(fill = "white"))

p3 <- ggplot(sample_data, aes(x = feature2, y = feature3, color = class)) +
  geom_point(size = 2, alpha = 0.7) +
  labs(title = "Sample Data: Original Features 2 vs 3",
       x = "Feature 2", y = "Feature 3") +
  theme_minimal() +
  theme(legend.position = "bottom",
        panel.background = element_rect(fill = "white"))

# Arrange plots
grid.arrange(p1, p2, p3, ncol = 2)

Applying LDA to Sample Data

Now let’s apply LDA to see how it transforms our data:

# Prepare data for LDA
X_sample <- as.matrix(sample_data[, 1:3])
y_sample <- sample_data$class

# Perform LDA
lda_sample <- lda(X_sample, y_sample)

# Transform data
X_lda_sample <- predict(lda_sample, X_sample)$x

# Display LDA results
cat("Number of classes:", length(lda_sample$lev), "\n")
## Number of classes: 3
cat("Prior probabilities:\n")
## Prior probabilities:
print(lda_sample$prior)
##   Class A   Class B   Class C 
## 0.3333333 0.3333333 0.3333333
cat("\nLDA coefficients:\n")
## 
## LDA coefficients:
print(lda_sample$scaling)
##                  LD1         LD2
## feature1 -0.07406551 -0.04545376
## feature2  0.06960232 -0.08690477
## feature3  0.03710453  0.06520045
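The `$scaling` matrix is all `predict()` uses for the projection: the scores are the centered features multiplied by these coefficients. As a quick sanity check (shown here on built-in iris data rather than the sample data above), the projection can be reproduced by hand up to a constant shift:

```r
library(MASS)

fit  <- lda(Species ~ ., data = iris)
auto <- predict(fit)$x[, "LD1"]

# Manual projection: features times the LD1 coefficient column.
# predict() also centers the data before projecting, so the two
# versions differ only by a constant shift; center both to compare.
manual <- as.matrix(iris[, 1:4]) %*% fit$scaling[, "LD1"]
max(abs(scale(manual, scale = FALSE) - scale(auto, scale = FALSE)))
```

The maximum absolute difference should be numerically negligible, confirming that the coefficients fully determine the transformation.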

Let’s visualize the LDA transformation:

# Create visualization of LDA projection
lda_plot_data <- data.frame(
  LD1 = X_lda_sample[,1], 
  LD2 = X_lda_sample[,2], 
  class = y_sample
)

ggplot(lda_plot_data, aes(x = LD1, y = LD2, color = class)) +
  geom_point(size = 2, alpha = 0.7) +
  stat_ellipse(level = 0.95) +
  labs(title = "Sample Data: LDA Projection",
       subtitle = "Notice how well the classes are separated after LDA transformation",
       x = "Linear Discriminant 1", 
       y = "Linear Discriminant 2") +
  theme_minimal() +
  theme(legend.position = "bottom",
        panel.background = element_rect(fill = "white"))

Pokemon Dataset Analysis

Now let’s apply LDA to a real-world dataset: the Pokemon dataset. We’ll predict each Pokemon’s type from its base stats.

Loading and Exploring the Data

# Load Pokemon dataset
pokemon <- read.csv("Pokemon-Dataset/pokemon.csv")

# Display basic information
cat("Dataset dimensions:", nrow(pokemon), "rows ×", ncol(pokemon), "columns\n")
## Dataset dimensions: 801 rows × 41 columns
cat("Columns:", paste(colnames(pokemon), collapse = ", "), "\n")
## Columns: abilities, against_bug, against_dark, against_dragon, against_electric, against_fairy, against_fight, against_fire, against_flying, against_ghost, against_grass, against_ground, against_ice, against_normal, against_poison, against_psychic, against_rock, against_steel, against_water, attack, base_egg_steps, base_happiness, base_total, capture_rate, classfication, defense, experience_growth, height_m, hp, japanese_name, name, percentage_male, pokedex_number, sp_attack, sp_defense, speed, type1, type2, weight_kg, generation, is_legendary
# Display first few rows
head(pokemon[, c("name", "type1", "type2", "hp", "attack", "defense", "sp_attack", "sp_defense", "speed")])
##         name type1  type2 hp attack defense sp_attack sp_defense speed
## 1  Bulbasaur grass poison 45     49      49        65         65    45
## 2    Ivysaur grass poison 60     62      63        80         80    60
## 3   Venusaur grass poison 80    100     123       122        120    80
## 4 Charmander  fire        39     52      43        60         50    65
## 5 Charmeleon  fire        58     64      58        80         65    80
## 6  Charizard  fire flying 78    104      78       159        115   100

Data Preprocessing

# Select only Pokemon with single type (no dual types)
single_type_pokemon <- pokemon[is.na(pokemon$type2) | pokemon$type2 == "", ]

cat("Single-type Pokemon count:", nrow(single_type_pokemon), "\n")
## Single-type Pokemon count: 384
# Select relevant features for analysis
features <- c("hp", "attack", "defense", "sp_attack", "sp_defense", "speed", "type1")
pokemon_subset <- single_type_pokemon[, features]

# Remove rows with missing values
pokemon_subset <- pokemon_subset[complete.cases(pokemon_subset), ]

cat("Final dataset size:", nrow(pokemon_subset), "rows\n")
## Final dataset size: 384 rows
# Show type distribution
type_counts <- table(pokemon_subset$type1)
cat("Type distribution:\n")
## Type distribution:
print(type_counts)
## 
##      bug     dark   dragon electric    fairy fighting     fire   flying 
##       18        9       12       26       16       22       27        1 
##    ghost    grass   ground      ice   normal   poison  psychic     rock 
##        9       37       10       12       61       13       35       11 
##    steel    water 
##        4       61
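With only one flying and four steel Pokemon, some classes are too small to estimate reliably (a single flying example makes its per-class accuracy all-or-nothing, and there are too few points even to draw a confidence ellipse). One hedged preprocessing option, shown here on a toy frame rather than the actual `pokemon_subset`, is to drop classes below a minimum count before fitting:

```r
# Hedged sketch: drop classes with fewer than min_n members before LDA.
# toy_df stands in for pokemon_subset; the type1 column name matches above.
min_n <- 5
toy_df <- data.frame(
  stat  = rnorm(12),
  type1 = c(rep("water", 6), rep("steel", 4), rep("flying", 2))
)
keep     <- names(which(table(toy_df$type1) >= min_n))
filtered <- toy_df[toy_df$type1 %in% keep, ]
table(filtered$type1)  # only water (6) survives the cutoff
```

The trade-off is that rare types simply cannot be predicted at all after filtering, so whether this is acceptable depends on the goal of the analysis.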

Let’s visualize the type distribution:

# Create type distribution plot
type_df <- data.frame(
  type = names(type_counts),
  count = as.numeric(type_counts)
)

ggplot(type_df, aes(x = reorder(type, count), y = count)) +
  geom_bar(stat = "identity", fill = "steelblue", alpha = 0.8) +
  coord_flip() +
  labs(title = "Pokemon Type Distribution",
       subtitle = "Number of Pokemon per type (single-type only)",
       x = "Pokemon Type", 
       y = "Count") +
  theme_minimal() +
  theme(axis.text.y = element_text(size = 10),
        panel.background = element_rect(fill = "white"))

Feature Analysis

Let’s examine the distribution of Pokemon stats:

# Create histograms for each stat
pokemon_numeric <- pokemon_subset[, 1:6]  # Select only numeric columns
pokemon_long <- data.frame(
  stat = rep(names(pokemon_numeric), each = nrow(pokemon_numeric)),
  value = as.vector(as.matrix(pokemon_numeric))
)

ggplot(pokemon_long, aes(x = value, fill = stat)) +
  geom_histogram(bins = 30, alpha = 0.7) +
  facet_wrap(~stat, scales = "free", ncol = 2) +
  labs(title = "Distribution of Pokemon Stats",
       x = "Stat Value", y = "Frequency") +
  theme_minimal() +
  theme(legend.position = "none",
        panel.background = element_rect(fill = "white"))

Correlation Analysis

# Calculate correlation matrix
numeric_features <- pokemon_subset[, 1:6]
cor_matrix <- cor(numeric_features)

# Create correlation plot
corrplot(cor_matrix, method = "color", type = "upper", 
         addCoef.col = "black", tl.col = "black", tl.srt = 45,
         title = "Pokemon Stats Correlation Matrix",
         mar = c(0,0,2,0))

Applying LDA to Pokemon Data

# Prepare data for LDA
X <- as.matrix(pokemon_subset[, 1:6])  # Features
y <- pokemon_subset$type1               # Target variable

# Standardize features (zero mean, unit variance) so the LDA
# coefficients are comparable across stats
X_scaled <- scale(X)

# Perform LDA
lda_model <- lda(X_scaled, y)

cat("LDA model fitted successfully\n")
## LDA model fitted successfully
cat("Number of classes:", length(lda_model$lev), "\n")
## Number of classes: 18
cat("Prior probabilities:\n")
## Prior probabilities:
print(lda_model$prior)
##         bug        dark      dragon    electric       fairy    fighting 
## 0.046875000 0.023437500 0.031250000 0.067708333 0.041666667 0.057291667 
##        fire      flying       ghost       grass      ground         ice 
## 0.070312500 0.002604167 0.023437500 0.096354167 0.026041667 0.031250000 
##      normal      poison     psychic        rock       steel       water 
## 0.158854167 0.033854167 0.091145833 0.028645833 0.010416667 0.158854167
# Transform data using LDA
X_lda <- predict(lda_model, X_scaled)$x

LDA Coefficients Analysis

# Extract coefficients
coef_matrix <- lda_model$scaling
feature_names <- c("HP", "Attack", "Defense", "Sp. Attack", "Sp. Defense", "Speed")

# Create coefficient heatmap
coef_df <- data.frame(
  feature = rep(feature_names, ncol(coef_matrix)),
  discriminant = rep(paste0("LD", 1:ncol(coef_matrix)), each = length(feature_names)),
  coefficient = as.vector(coef_matrix)
)

ggplot(coef_df, aes(x = discriminant, y = feature, fill = coefficient)) +
  geom_tile() +
  scale_fill_gradient2(low = "blue", mid = "white", high = "red", 
                       midpoint = 0, name = "Coefficient") +
  labs(title = "LDA Coefficients Heatmap",
       subtitle = "Shows how each feature contributes to each discriminant",
       x = "Linear Discriminants", 
       y = "Features") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        panel.background = element_rect(fill = "white"))

Feature Importance

# Calculate feature importance (sum of absolute coefficients across discriminants)
feature_importance <- rowSums(abs(coef_matrix))

importance_df <- data.frame(
  feature = feature_names,
  importance = feature_importance
)

ggplot(importance_df, aes(x = reorder(feature, importance), y = importance)) +
  geom_bar(stat = "identity", fill = "steelblue", alpha = 0.8) +
  coord_flip() +
  labs(title = "Feature Importance in LDA",
       subtitle = "Sum of absolute coefficients across all discriminants",
       x = "Features", 
       y = "Importance Score") +
  theme_minimal() +
  theme(axis.text.y = element_text(size = 10),
        panel.background = element_rect(fill = "white"))

LDA Projections Visualization

# Create data frame for plotting
plot_data <- data.frame(
  LD1 = X_lda[,1],
  LD2 = X_lda[,2],
  LD3 = X_lda[,3],
  type = y
)

# Plot first two discriminants
p1 <- ggplot(plot_data, aes(x = LD1, y = LD2, color = type)) +
  geom_point(alpha = 0.6, size = 2) +
  stat_ellipse(level = 0.95, alpha = 0.3) +
  labs(title = "Pokemon Types: LDA Projection (LD1 vs LD2)",
       x = "Linear Discriminant 1", 
       y = "Linear Discriminant 2") +
  theme_minimal() +
  theme(legend.position = "bottom",
        legend.text = element_text(size = 8),
        panel.background = element_rect(fill = "white"))

# Plot first vs third discriminant
p2 <- ggplot(plot_data, aes(x = LD1, y = LD3, color = type)) +
  geom_point(alpha = 0.6, size = 2) +
  stat_ellipse(level = 0.95, alpha = 0.3) +
  labs(title = "Pokemon Types: LDA Projection (LD1 vs LD3)",
       x = "Linear Discriminant 1", 
       y = "Linear Discriminant 3") +
  theme_minimal() +
  theme(legend.position = "bottom",
        legend.text = element_text(size = 8),
        panel.background = element_rect(fill = "white"))

# Arrange plots
grid.arrange(p1, p2, ncol = 2)

Performance Evaluation

# Make predictions
predictions <- predict(lda_model, X_scaled)

# Calculate accuracy on the training data (an optimistic estimate,
# since the same rows were used to fit the model)
accuracy <- mean(predictions$class == y)
cat("Overall accuracy:", round(accuracy * 100, 2), "%\n")
## Overall accuracy: 33.59 %
# Create confusion matrix
conf_matrix <- table(Actual = y, Predicted = predictions$class)

# Calculate per-class accuracy
per_class_accuracy <- diag(conf_matrix) / rowSums(conf_matrix)
cat("\nPer-class accuracy:\n")
## 
## Per-class accuracy:
print(round(per_class_accuracy * 100, 2))
##      bug     dark   dragon electric    fairy fighting     fire   flying 
##    44.44     0.00     0.00    23.08    25.00    40.91     7.41     0.00 
##    ghost    grass   ground      ice   normal   poison  psychic     rock 
##    33.33     0.00    20.00     0.00    57.38     0.00    42.86    45.45 
##    steel    water 
##    25.00    63.93
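One caveat: the figure above is training accuracy, since the same rows were used to fit and evaluate the model. `MASS::lda` accepts a `CV = TRUE` argument that returns leave-one-out cross-validated predictions instead, which gives a less optimistic estimate. A sketch on iris as stand-in data:

```r
library(MASS)

# CV = TRUE returns leave-one-out cross-validated class predictions
# (the fitted object then holds $class and $posterior, not a model)
cv_fit <- lda(Species ~ ., data = iris, CV = TRUE)
cv_accuracy <- mean(cv_fit$class == iris$Species)
round(cv_accuracy * 100, 2)
```

Applied to the Pokemon data, the cross-validated accuracy would likely fall somewhat below the 33.6% reported here.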

Let’s visualize the confusion matrix:

# Create confusion matrix heatmap
conf_df <- as.data.frame(conf_matrix)
conf_df$Actual <- factor(conf_df$Actual, levels = unique(conf_df$Actual))
conf_df$Predicted <- factor(conf_df$Predicted, levels = unique(conf_df$Predicted))

ggplot(conf_df, aes(x = Predicted, y = Actual, fill = Freq)) +
  geom_tile() +
  scale_fill_gradient(low = "white", high = "red", name = "Count") +
  geom_text(aes(label = Freq), color = "black", size = 3) +
  labs(title = "Confusion Matrix Heatmap",
       subtitle = paste("Overall Accuracy:", round(accuracy * 100, 2), "%"),
       x = "Predicted Type", 
       y = "Actual Type") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        axis.text.y = element_text(size = 8),
        panel.background = element_rect(fill = "white"))

Discussion and Interpretation

Key Findings

  1. Sample Data: LDA successfully separated the three synthetic classes, demonstrating its effectiveness on well-separated data.

  2. Pokemon Classification: The overall training accuracy of ~33.6% suggests that Pokemon types are not easily separable using only their base stats.

  3. Feature Importance: Some stats (like Speed and Attack) appear more important for type classification than others.

  4. Class Separability: The overlapping ellipses in the LDA projections indicate that many Pokemon types have similar stat distributions.

Limitations and Considerations

  • Non-linear Relationships: Pokemon types may have non-linear relationships with stats that LDA cannot capture.
  • Feature Engineering: Additional features like abilities, moves, or evolutionary stage might improve classification.
  • Class Imbalance: Some Pokemon types have very few representatives, affecting model performance.
  • Assumption Violations: Real data rarely perfectly meets LDA’s Gaussian and equal covariance assumptions.
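On the last point, a quick and informal way to eyeball the equal-covariance assumption is to compare each class's covariance matrix against the pooled one; a formal alternative is Box's M test, available in add-on packages such as biotools. A base-R sketch on iris as stand-in data:

```r
# Informal check of the equal-covariance assumption (iris as stand-in)
covs <- lapply(split(iris[, 1:4], iris$Species), cov)

# Pool the per-class covariances (classes are equal-sized, so a plain average)
pooled <- Reduce(`+`, covs) / length(covs)

# Frobenius distance of each class covariance from the pooled matrix;
# large or very uneven distances hint that the assumption is strained
sapply(covs, function(S) sqrt(sum((S - pooled)^2)))
```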

Applications and Extensions

  • Dimensionality Reduction: LDA can be used as a preprocessing step for other classification algorithms.
  • Feature Selection: The coefficients help identify which features are most discriminative.
  • Data Exploration: LDA projections reveal the structure and separability of classes in the data.
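The preprocessing idea in the first bullet can be sketched concretely: project the data onto the discriminants, then feed the low-dimensional scores to another classifier. Here is a hedged example on iris using kNN from the class package (which ships with R); the split, seed, and k value are illustrative choices, not recommendations:

```r
library(MASS)
library(class)  # ships with R; provides knn()

set.seed(1)
idx <- sample(nrow(iris), 100)

# Fit LDA on a training split, then use the LD scores as the
# reduced (2-dimensional) feature space for a kNN classifier
fit     <- lda(Species ~ ., data = iris[idx, ])
train_x <- predict(fit, iris[idx, ])$x
test_x  <- predict(fit, iris[-idx, ])$x
pred    <- knn(train_x, test_x, cl = iris$Species[idx], k = 5)
mean(pred == iris$Species[-idx])  # accuracy on the held-out rows
```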

Conclusion

Linear Discriminant Analysis provides valuable insights into the Pokemon dataset, revealing both the potential and limitations of linear classification approaches. While the overall accuracy suggests that Pokemon types cannot be perfectly predicted from base stats alone, LDA successfully identifies the most discriminative features and provides a foundation for more sophisticated analysis.

The technique demonstrates its strength in dimensionality reduction and class separability analysis, making it a valuable tool for exploratory data analysis and preprocessing in machine learning pipelines.

References

  • Fisher, R. A. (1936). “The Use of Multiple Measurements in Taxonomic Problems.” Annals of Eugenics, 7(2), 179-188.
  • Venables, W. N., & Ripley, B. D. MASS package documentation for the lda implementation.
  • Pokemon dataset, Kaggle.